import contextlib
import sys

import numpy as np
import pandas as pd
import statsmodels.api as sm

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Estimate the disparity by an OLS regression of ``outcome`` on ``pbvarb``.

    Fits the formula ``outcome ~ pbvarb`` on ``dataset`` with HC1
    (heteroskedasticity-robust) standard errors.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain the columns named by ``pbvarb`` and ``outcome``.
    pbvarb : str
        Name of the regressor column.
    outcome : str
        Name of the dependent-variable column.

    Returns
    -------
    tuple
        ``(coef, se, black_aud_rate, non_black_aud_rate)`` where ``coef`` and
        ``se`` are the coefficient on ``pbvarb`` and its standard error,
        ``non_black_aud_rate`` is the first fitted parameter (the intercept,
        i.e. the fitted value at ``pbvarb = 0``) and ``black_aud_rate`` is the
        sum of the first two fitted parameters (the fitted value at
        ``pbvarb = 1``).
    """
    model = sm.OLS.from_formula(f"{outcome} ~ {pbvarb}", data=dataset).fit(cov_type='HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    # params is a Series indexed by term name; integer keys on such a Series
    # rely on a deprecated positional fallback, so positions go through .iloc.
    intercept = model.params.iloc[0]
    slope = model.params.iloc[1]
    black_aud_rate = intercept + slope
    non_black_aud_rate = intercept
    return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Estimate the disparity as a difference of probability-weighted means.

    The outcome is averaged twice: once with each row weighted by ``pbvarb``
    and once with each row weighted by ``1 - pbvarb``.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must contain the columns named by ``pbvarb`` and ``outcome``.
    pbvarb : str
        Name of the weight column.
    outcome : str
        Name of the outcome column.

    Returns
    -------
    tuple
        ``(est, black_aud_rate, non_black_aud_rate)`` where ``est`` is
        ``black_aud_rate - non_black_aud_rate``.
    """
    weight_black = dataset[pbvarb]
    weight_non_black = 1 - weight_black
    result = dataset[outcome]

    black_aud_rate = (weight_black * result).sum() / weight_black.sum()
    non_black_aud_rate = (weight_non_black * result).sum() / weight_non_black.sum()

    return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset,pbvarb='predicted_prob_black',outcome='audited', seReg = 0.0):
    """Return the supplied standard error together with its rescaled version.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Passed to ``getSEMultiplier`` to compute the scaling factor.
    pbvarb : str
        Column name passed to ``getSEMultiplier``.
    outcome : str
        Not used in the body; kept in the signature.
    seReg : float
        Standard error to rescale. It is supplied by the caller rather than
        re-estimated here.

    Returns
    -------
    tuple
        ``(seReg, seChen)`` where ``seChen`` is ``seReg`` multiplied by
        ``getSEMultiplier(dataset, pbvarb)``.
    """
    scale = getSEMultiplier(dataset, pbvarb)
    return seReg, seReg * scale

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
    """Return sqrt(var(p) / (mean(p) * (1 - mean(p)))) for the column ``pbvarb``.

    ``var`` is the pandas ``Series.var`` of the column and ``mean`` is its
    ``Series.mean``.
    """
    prob = dataset[pbvarb]
    avg = prob.mean()
    ratio = prob.var() / (avg * (1 - avg))
    return np.sqrt(ratio)

individual2014 = pd.read_csv('/REDACTED/data/final/individualBISG2014_full_final.csv')

# In Table 1 of the paper (Estimated Audit Rate Disparity), audit categories are defined slightly differently:
# There, we assign a correspondence audit to any taxpayer that has an employee_code containing 5, and assign a
# non-correspondence audit to all other audited taxpayers. Here, we do not consider taxpayers with 'mixed'
# values of employee_code, or a value equal to [9]. Instead we just consider those with an employee_code equal to
# [1], [2], or [5], and assign a correspondence audit to taxpayers that have an employee_code of [5]. This
# leads to 672 fewer audits than we consider in the variable aud_no_research_audits, and 1,313 fewer audits than we 
# consider in the correspondence/non-correspondence columns of table 1. 1,037 taxpayers are classified as
# having correspondence audits in table 1 that are not considered to have correspondence audits here.

# assign audit indicators for different types, compare to alternative audit classifications
# Each indicator is 1 only when the row is audited (aud_no_research_audits == 1)
# and its employee_code equals the single code for that audit type.
individual2014['field_aud'] = np.where((individual2014['employee_code'] == '[1]') & (individual2014['aud_no_research_audits'] == 1), 1, 0)
# Diagnostics are printed explicitly: a bare expression is only echoed by an
# interactive interpreter and shows nothing when this file runs as a script.
print(individual2014.value_counts(['field_aud', 'aud_no_research_audits']))

individual2014['office_aud'] = np.where((individual2014['employee_code'] == '[2]') & (individual2014['aud_no_research_audits'] == 1), 1, 0)
individual2014['correspondence_aud'] = np.where((individual2014['employee_code'] == '[5]') & (individual2014['aud_no_research_audits'] == 1), 1, 0)

# keep only rows that have a predicted probability
individual2014 = individual2014[individual2014.predicted_prob_black.notnull()]

print(individual2014.value_counts(['correspondence_aud', 'corr_aud']))

# Compare the number of audits under the three classifications. Summing the
# boolean masks counts the matching rows without building filtered copies.
print('number of corr_aud + non_corr_aud:',
      int((individual2014['corr_aud'] == 1).sum()) + int((individual2014['non_corr_aud'] == 1).sum()))
print('number of field_aud + office_aud + correspondence_aud:',
      int((individual2014['field_aud'] == 1).sum())
      + int((individual2014['office_aud'] == 1).sum())
      + int((individual2014['correspondence_aud'] == 1).sum()))
print('number of aud_no_research_audits:',
      int((individual2014['aud_no_research_audits'] == 1).sum()))
#number of corr_aud + non_corr_aud: 781,268
#number of field_aud + office_aud + correspondence_aud: 779,955
#number of aud_no_research_audits: 780,627

# write out probabilistic results
# Every reported value is written with an explicit print(): a bare expression is
# only echoed by an interactive interpreter, so when this file runs as a script
# the report would otherwise contain the headers and no numbers. The context
# managers close the report file and put back the previous sys.stdout even if a
# step below raises.
with open("/REDACTED/audit_burdens_prob.txt", "w") as prob_report, contextlib.redirect_stdout(prob_report):

    # number of taxpayers overall and with each audit indicator set
    n_taxpayers = len(individual2014)
    n_correspondence_aud = int((individual2014['correspondence_aud'] == 1).sum())
    n_office_aud = int((individual2014['office_aud'] == 1).sum())
    n_field_aud = int((individual2014['field_aud'] == 1).sum())

    # the overall audit rate is the sum of the three per-type rates
    print("Audit Rate (pp)(Correspondence, Office, Field, Overall)")
    print(round(100*n_correspondence_aud/n_taxpayers, 2))
    print(round(100*n_office_aud/n_taxpayers, 2))
    print(round(100*n_field_aud/n_taxpayers, 2))
    print(round(100*((n_field_aud/n_taxpayers) + (n_office_aud/n_taxpayers) + (n_correspondence_aud/n_taxpayers)), 2))


    # getting shares of each audit type

    aud_data = individual2014[individual2014.aud_no_research_audits == 1]

    # removing [9] for now
    valid_codes = ['[1]', '[2]', '[5]']
    aud_data = aud_data[aud_data.employee_code.isin(valid_codes)]
    # removes 672 entries

    n_audits = len(aud_data)
    field_share = int((aud_data.field_aud == 1).sum()) / n_audits
    office_share = int((aud_data.office_aud == 1).sum()) / n_audits
    correspondence_share = int((aud_data.correspondence_aud == 1).sum()) / n_audits

    print("\nShare of Total Audits (Correspondence, Office, Field, Overall)")
    print(round(correspondence_share, 2))
    print(round(office_share, 2))
    print(round(field_share, 2))
    print(round(correspondence_share + office_share + field_share, 2))

    # from Guyton et al
    correspondence_avg_hours = 30
    office_avg_hours = 38
    field_avg_hours = 34
    # each "overall" figure below is the audit-share-weighted average of the per-type figures
    overall_hours = correspondence_share*correspondence_avg_hours + office_share*office_avg_hours + field_share*field_avg_hours

    print("\nTaxpayer Hours (Correspondence, Office, Field, Overall)")
    print(correspondence_avg_hours)
    print(office_avg_hours)
    print(field_avg_hours)
    print(round(overall_hours, 2))

    # the following values are inflation-adjusted using the calculator at https://www.bls.gov/data/inflation_calculator.htm
    correspondence_avg_tcc = 643 # adjusted from 580
    office_avg_tcc = 1717 # adjusted from 1,550
    field_avg_tcc = 4431 # adjusted from 4,000
    overall_tcc = correspondence_share*correspondence_avg_tcc + office_share*office_avg_tcc + field_share*field_avg_tcc

    print("\nTaxpayer Compliance Cost (Correspondence, Office, Field, Overall)")
    print(correspondence_avg_tcc)
    print(office_avg_tcc)
    print(field_avg_tcc)
    print(round(overall_tcc, 2))

    # the following values come from AGI_Exam_Stats.xlsx, from Tom Hertz
    correspondence_avg_pen_int = 320.71
    office_avg_pen_int = 1580.77
    field_avg_pen_int = 6434.52
    overall_pen_int = correspondence_share*correspondence_avg_pen_int + office_share*office_avg_pen_int + field_share*field_avg_pen_int

    print("\nPenalties and Interest (Correspondence, Office, Field, Overall)")
    print(correspondence_avg_pen_int)
    print(office_avg_pen_int)
    print(field_avg_pen_int)
    print(round(overall_pen_int, 2))


    # the following values come from AGI_Exam_Stats.xlsx, from Tom Hertz
    correspondence_avg_tax_assmnt = 5252.56
    office_avg_tax_assmnt = 7130.46
    field_avg_tax_assmnt = 24960.56
    overall_tax_assmnt = correspondence_share*correspondence_avg_tax_assmnt + office_share*office_avg_tax_assmnt + field_share*field_avg_tax_assmnt

    print("\nAssessed Taxes (Correspondence, Office, Field, Overall)")
    print(correspondence_avg_tax_assmnt)
    print(office_avg_tax_assmnt)
    print(field_avg_tax_assmnt)
    print(round(overall_tax_assmnt, 2))

    # calculate audit rate disparities, use outputs to calculate other disparities
    prob_field_aud_disp = chenEstimate(individual2014, pbvarb = 'predicted_prob_black', outcome = 'field_aud')[0]
    prob_office_aud_disp = chenEstimate(individual2014, pbvarb = 'predicted_prob_black', outcome = 'office_aud')[0]
    prob_correspondence_aud_disp = chenEstimate(individual2014, pbvarb = 'predicted_prob_black', outcome = 'correspondence_aud')[0]
    # NOTE(review): the overall disparities are share-weighted averages of the
    # per-type disparities, whereas the overall audit rate above is a plain sum
    # of the per-type rates - confirm this is the intended aggregation.
    prob_overall_aud_disp = (correspondence_share*prob_correspondence_aud_disp + office_share*prob_office_aud_disp + field_share*prob_field_aud_disp)*100

    print("\nDisparity (Audit Rate, pp) (Correspondence, Office, Field, Overall)")
    print(round(prob_correspondence_aud_disp*100, 2))
    print(round(prob_office_aud_disp*100, 2))
    print(round(prob_field_aud_disp*100, 2))
    print(round(prob_overall_aud_disp, 2))


    # each burden disparity is the audit rate disparity times the per-audit average for that type
    print("\nDisparity (Hours) (Correspondence, Office, Field, Overall)")
    prob_correspondence_hours_disp = (prob_correspondence_aud_disp)*correspondence_avg_hours
    print(round(prob_correspondence_hours_disp, 3))
    prob_office_hours_disp = (prob_office_aud_disp)*office_avg_hours
    print(round(prob_office_hours_disp, 3))
    prob_field_hours_disp = (prob_field_aud_disp)*field_avg_hours
    print(round(prob_field_hours_disp, 3))
    prob_overall_hours_disp = correspondence_share*prob_correspondence_hours_disp + office_share*prob_office_hours_disp + field_share*prob_field_hours_disp
    print(round(prob_overall_hours_disp, 3))

    print("\nDisparity (Compliance Cost) (Correspondence, Office, Field, Overall)")
    prob_correspondence_tcc_disp = (prob_correspondence_aud_disp)*correspondence_avg_tcc
    print(round(prob_correspondence_tcc_disp, 2))
    prob_office_tcc_disp = (prob_office_aud_disp)*office_avg_tcc
    print(round(prob_office_tcc_disp, 2))
    prob_field_tcc_disp = (prob_field_aud_disp)*field_avg_tcc
    print(round(prob_field_tcc_disp, 2))
    prob_overall_tcc_disp = correspondence_share*prob_correspondence_tcc_disp + office_share*prob_office_tcc_disp + field_share*prob_field_tcc_disp
    print(round(prob_overall_tcc_disp, 2))

    print("\nDisparity (Penalties and Interest) (Correspondence, Office, Field, Overall)")
    prob_correspondence_pen_int_disp = (prob_correspondence_aud_disp)*correspondence_avg_pen_int
    print(round(prob_correspondence_pen_int_disp, 2))
    prob_office_pen_int_disp = (prob_office_aud_disp)*office_avg_pen_int
    print(round(prob_office_pen_int_disp, 2))
    prob_field_pen_int_disp = (prob_field_aud_disp)*field_avg_pen_int
    print(round(prob_field_pen_int_disp, 2))
    prob_overall_pen_int_disp = correspondence_share*prob_correspondence_pen_int_disp + office_share*prob_office_pen_int_disp + field_share*prob_field_pen_int_disp
    print(round(prob_overall_pen_int_disp, 2))

    print("\nDisparity (Assessed Taxes) (Correspondence, Office, Field, Overall)")
    prob_correspondence_tax_assmnt_disp = (prob_correspondence_aud_disp)*correspondence_avg_tax_assmnt
    print(round(prob_correspondence_tax_assmnt_disp, 2))
    prob_office_tax_assmnt_disp = (prob_office_aud_disp)*office_avg_tax_assmnt
    print(round(prob_office_tax_assmnt_disp, 2))
    prob_field_tax_assmnt_disp = (prob_field_aud_disp)*field_avg_tax_assmnt
    print(round(prob_field_tax_assmnt_disp, 2))
    prob_overall_tax_assmnt_disp = correspondence_share*prob_correspondence_tax_assmnt_disp + office_share*prob_office_tax_assmnt_disp + field_share*prob_field_tax_assmnt_disp
    print(round(prob_overall_tax_assmnt_disp, 2))

# write out linear results
# Every reported value is written with an explicit print(): a bare expression is
# only echoed by an interactive interpreter, so when this file runs as a script
# the report would otherwise contain the headers and no numbers. The context
# managers close the report file and put back the previous sys.stdout even if a
# step below raises.
with open("/REDACTED/audit_burdens_lin.txt", "w") as lin_report, contextlib.redirect_stdout(lin_report):

    # number of taxpayers overall and with each audit indicator set
    n_taxpayers = len(individual2014)
    n_correspondence_aud = int((individual2014['correspondence_aud'] == 1).sum())
    n_office_aud = int((individual2014['office_aud'] == 1).sum())
    n_field_aud = int((individual2014['field_aud'] == 1).sum())

    # first 6 code chunks are the same as probabilistic code above
    # (the shares, per-audit averages and overall_* values are the module-level
    # names assigned in the probabilistic section)
    print("Audit Rate (pp)(Correspondence, Office, Field, Overall)")
    print(round(100*n_correspondence_aud/n_taxpayers, 2))
    print(round(100*n_office_aud/n_taxpayers, 2))
    print(round(100*n_field_aud/n_taxpayers, 2))
    print(round(100*((n_field_aud/n_taxpayers) + (n_office_aud/n_taxpayers) + (n_correspondence_aud/n_taxpayers)), 2))

    print("\nShare of Total Audits (Correspondence, Office, Field, Overall)")
    print(round(correspondence_share, 2))
    print(round(office_share, 2))
    print(round(field_share, 2))
    print(round(correspondence_share + office_share + field_share, 2))

    print("\nTaxpayer Hours (Correspondence, Office, Field, Overall)")
    print(correspondence_avg_hours)
    print(office_avg_hours)
    print(field_avg_hours)
    print(round(overall_hours, 2))

    print("\nTaxpayer Compliance Cost (Correspondence, Office, Field, Overall)")
    print(correspondence_avg_tcc)
    print(office_avg_tcc)
    print(field_avg_tcc)
    print(round(overall_tcc, 2))

    print("\nPenalties and Interest (Correspondence, Office, Field, Overall)")
    print(correspondence_avg_pen_int)
    print(office_avg_pen_int)
    print(field_avg_pen_int)
    print(round(overall_pen_int, 2))

    print("\nAssessed Taxes (Correspondence, Office, Field, Overall)")
    print(correspondence_avg_tax_assmnt)
    print(office_avg_tax_assmnt)
    print(field_avg_tax_assmnt)
    print(round(overall_tax_assmnt, 2))

    # calculate audit rate disparities, use outputs to calculate other disparities
    lin_field_aud_disp = regEstimate(individual2014, pbvarb = 'predicted_prob_black', outcome = 'field_aud')[0]
    lin_office_aud_disp = regEstimate(individual2014, pbvarb = 'predicted_prob_black', outcome = 'office_aud')[0]
    lin_correspondence_aud_disp = regEstimate(individual2014, pbvarb = 'predicted_prob_black', outcome = 'correspondence_aud')[0]
    # NOTE(review): the overall disparities are share-weighted averages of the
    # per-type disparities, whereas the overall audit rate above is a plain sum
    # of the per-type rates - confirm this is the intended aggregation.
    lin_overall_aud_disp = (correspondence_share*lin_correspondence_aud_disp + office_share*lin_office_aud_disp + field_share*lin_field_aud_disp)*100
    print("\nDisparity (Audit Rate, pp) (Correspondence, Office, Field, Overall)")
    print(round(lin_correspondence_aud_disp*100, 2))
    print(round(lin_office_aud_disp*100, 2))
    print(round(lin_field_aud_disp*100, 2))
    print(round(lin_overall_aud_disp, 2))


    # each burden disparity is the audit rate disparity times the per-audit average for that type
    print("\nDisparity (Hours) (Correspondence, Office, Field, Overall)")
    lin_correspondence_hours_disp = (lin_correspondence_aud_disp)*correspondence_avg_hours
    print(round(lin_correspondence_hours_disp, 3))
    lin_office_hours_disp = (lin_office_aud_disp)*office_avg_hours
    print(round(lin_office_hours_disp, 3))
    lin_field_hours_disp = (lin_field_aud_disp)*field_avg_hours
    print(round(lin_field_hours_disp, 3))
    lin_overall_hours_disp = correspondence_share*lin_correspondence_hours_disp + office_share*lin_office_hours_disp + field_share*lin_field_hours_disp
    print(round(lin_overall_hours_disp, 3))

    print("\nDisparity (Compliance Cost) (Correspondence, Office, Field, Overall)")
    lin_correspondence_tcc_disp = (lin_correspondence_aud_disp)*correspondence_avg_tcc
    print(round(lin_correspondence_tcc_disp, 2))
    lin_office_tcc_disp = (lin_office_aud_disp)*office_avg_tcc
    print(round(lin_office_tcc_disp, 2))
    lin_field_tcc_disp = (lin_field_aud_disp)*field_avg_tcc
    print(round(lin_field_tcc_disp, 2))
    lin_overall_tcc_disp = correspondence_share*lin_correspondence_tcc_disp + office_share*lin_office_tcc_disp + field_share*lin_field_tcc_disp
    print(round(lin_overall_tcc_disp, 2))

    print("\nDisparity (Penalties and Interest)  (Correspondence, Office, Field, Overall)")
    lin_correspondence_pen_int_disp = (lin_correspondence_aud_disp)*correspondence_avg_pen_int
    print(round(lin_correspondence_pen_int_disp, 2))
    lin_office_pen_int_disp = (lin_office_aud_disp)*office_avg_pen_int
    print(round(lin_office_pen_int_disp, 2))
    lin_field_pen_int_disp = (lin_field_aud_disp)*field_avg_pen_int
    print(round(lin_field_pen_int_disp, 2))
    lin_overall_pen_int_disp = correspondence_share*lin_correspondence_pen_int_disp + office_share*lin_office_pen_int_disp + field_share*lin_field_pen_int_disp
    print(round(lin_overall_pen_int_disp, 2))

    print("\nDisparity (Assessed Taxes)  (Correspondence, Office, Field, Overall)")
    lin_correspondence_tax_assmnt_disp = (lin_correspondence_aud_disp)*correspondence_avg_tax_assmnt
    print(round(lin_correspondence_tax_assmnt_disp, 2))
    lin_office_tax_assmnt_disp = (lin_office_aud_disp)*office_avg_tax_assmnt
    print(round(lin_office_tax_assmnt_disp, 2))
    lin_field_tax_assmnt_disp = (lin_field_aud_disp)*field_avg_tax_assmnt
    print(round(lin_field_tax_assmnt_disp, 2))
    lin_overall_tax_assmnt_disp = correspondence_share*lin_correspondence_tax_assmnt_disp + office_share*lin_office_tax_assmnt_disp + field_share*lin_field_tax_assmnt_disp
    print(round(lin_overall_tax_assmnt_disp, 2))